#include "cuda_runtime.h"
#include "cuda.h"
#include "cuda_fp16.h"
#include "tensor.h"
#include "embedding_weight.h"

/*
  @brief: Declaration of the launcher for the input embedding kernel.
          It carries no __global__/__device__ qualifier, so it is an ordinary
          host function; the kernel itself and this function's definition
          are not visible in this declaration.
  T: the data type of the embedding (shared by `out` and `embed_table`).
     NOTE(review): cuda_fp16.h is included, which suggests T may be
     instantiated with half as well as float -- confirm against the definition.
  @param input_ids:   integer tensor; presumably the token ids used to index
                      the embedding table -- TODO confirm.
  @param out:         tensor of T; presumably receives the looked-up
                      embedding vectors -- TODO confirm.
  @param embed_table: embedding weights of type T.
*/
template<typename T>
void launchInputEmbedding(TensorWrapper<int>* input_ids,
                          TensorWrapper<T>* out,
                          EmbeddingWeight<T>* embed_table);